{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/milvus-io/bootcamp/blob/master/bootcamp/RAG/advanced_rag/hyde_with_langchain.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# HyDE\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Google Colab preparation[optional]\n",
    "This is an optional step, if you want to run this notebook on Google Colab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "! git clone https://github.com/milvus-io/bootcamp.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import shutil\n",
    "src_dir = \"./bootcamp/bootcamp/RAG/advanced_rag/rag_utils\"\n",
    "dst_dir = \"./rag_utils\"\n",
    "shutil.copytree(src_dir, dst_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "! pip install --upgrade langchain langchain-community langchain_milvus langchain-openai bs4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "Please prepare you [OPENAI_API_KEY](https://openai.com/index/openai-api/) in your environment variables.\n",
    "![](imgs/colab_api_key1.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "### If you are running this notebook on Google Colab, you have to restart this session by `Cmd/Ctrl + M`, then press `.` to make the environment take effect."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from google.colab import userdata\n",
    "import os\n",
    "\n",
    "os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "----\n",
    "## Get started\n",
    "![](imgs/hyde.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prepare the data\n",
    "\n",
    "We use the Langchain WebBaseLoader to load documents from [blog sources](https://lilianweng.github.io/posts/2023-06-23-agent/) and split them into chunks using the RecursiveCharacterTextSplitter."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import bs4\n",
    "from langchain_community.document_loaders import WebBaseLoader\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "\n",
    "from rag_utils.vanilla import vectorstore\n",
    "\n",
    "# Create a WebBaseLoader instance to load documents from web sources\n",
    "loader = WebBaseLoader(\n",
    "    web_paths=(\"https://lilianweng.github.io/posts/2023-06-23-agent/\",),\n",
    "    bs_kwargs=dict(\n",
    "        parse_only=bs4.SoupStrainer(\n",
    "            class_=(\"post-content\", \"post-title\", \"post-header\")\n",
    "        )\n",
    "    ),\n",
    ")\n",
    "# Load documents from web sources using the loader\n",
    "documents = loader.load()\n",
    "# Initialize a RecursiveCharacterTextSplitter for splitting text into chunks\n",
    "text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)\n",
    "\n",
    "# Split the documents into chunks using the text_splitter\n",
    "docs = text_splitter.split_documents(documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Build the chain\n",
    "\n",
    "We load the docs into milvus vectorstore, and build a milvus retriever."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "vectorstore.add_documents(docs)\n",
    "retriever = vectorstore.as_retriever()"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Build the vanilla RAG chain."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from rag_utils.vanilla import format_docs, rag_prompt, llm\n",
    "\n",
    "vanilla_rag_chain = (\n",
    "    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
    "    | rag_prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Build a hyde chain."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from rag_utils.hyde import HydeRetriever\n",
    "\n",
    "hyde_retriever = HydeRetriever.from_vectorstore(vectorstore)\n",
    "\n",
    "hyde_chain = (\n",
    "    {\"context\": hyde_retriever | format_docs, \"question\": RunnablePassthrough()}\n",
    "    | rag_prompt\n",
    "    | llm\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Test the chain\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[vanilla_result]:\n",
      "The approximate nearest neighbors (ANN) algorithm is commonly used for vector approximate searching in a vector store. Some specific implementations of ANN algorithms include FAISS (Facebook AI Similarity Search), ScaNN (Scalable Nearest Neighbors), LSH (Locality-Sensitive Hashing), and ANNOY (Approximate Nearest Neighbors Oh Yeah).\n",
      "\n",
      "[hyde_result]:\n",
      "The vector approximate searching algorithms that work in a vector store include Locality-Sensitive Hashing (LSH), ANNOY (Approximate Nearest Neighbors Oh Yeah), FAISS (Facebook AI Similarity Search), ScaNN (Scalable Nearest Neighbors), and HNSW (Hierarchical Navigable Small World). These algorithms are commonly used for fast Maximum Inner Product Search (MIPS) in a vector store database.\n"
     ]
    }
   ],
   "source": [
    "query = \"which vector approximate searching algorithms work in a vector store\"\n",
    "\n",
    "vanilla_result = vanilla_rag_chain.invoke(query)\n",
    "hyde_result = hyde_chain.invoke(query)\n",
    "print(f\"\\n[vanilla_result]:\\n{vanilla_result}\\n\\n[hyde_result]:\\n{hyde_result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "In the [hyde_result], it retrieved with ground truth \"HNSW\", which does not appear in the [vanilla_result].\n",
    "\n",
    "Let's dive deep into the retrieved results to find the reason."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# retriever.invoke(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "![](imgs/hyde_res2.png)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hyde_retriever.invoke(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "![](imgs/hyde_res1.png)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "The hyde_retriever retrieved the document with `HNSW`, while the vanilla_retriever did not. This is because HyDE generates fake documents that may contain various vector approximate searching algorithms like `HNSW`, which makes the retriever's results more accurate."
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}